{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distill Reasoning Data from DeepSeek-R1\n",
    "\n",
    "In the field of LLMs, reasoning models leverage deep thinking capabilities to significantly enhance model performance across complex scenarios. According to the [DeepSeek-R1](https://arxiv.org/abs/2501.12948) paper, the reasoning pattern of larger models can be distilled into smaller models. Specifically, we can distill long-chain-of-thought (long-CoT) data that includes reasoning processes from DeepSeek-R1 and directly fine-tune open-source models like Qwen and Llama. This straightforward distillation method significantly enhances the reasoning abilities of smaller models.\n",
    "\n",
    "\n",
    "To demonstrate the complete distillation process, we have prepared two notebooks that cover how to distill reasoning data from DeepSeek-R1 using the NIM API, and how to train models using the distilled data.\n",
    "\n",
    "\n",
    "- [generate_reasoning_data.ipynb](./generate_reasoning_data.ipynb) (⭐) demonstrates how to distill reasoning data from DeepSeek-R1 using the NIM API. \n",
    "- [qwen2_distill_nemo.ipynb](./qwen2_distill_nemo.ipynb) shows how to train open-source models using the distilled data.\n",
    "\n",
    "\n",
    "This tutorial is part 1 of the series, and it will demonstrate how to distill reasoning data from the DeepSeek-R1 model using NVIDIA NIM.\n",
    "\n",
    "Prerequisites:\n",
    "- Obtain an NVIDIA API Key (visit [build.nvidia.com](https://build.nvidia.com/explore/discover) for details)\n",
    "\n",
    "This notebook contains three steps:\n",
    "1. Prepare the raw dataset\n",
    "2. Distill reasoning data from DeepSeek-R1 using NVIDIA NIM API\n",
    "3. Post-process the distilled data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install openai math_verify datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: NVIDIA_API_KEY=nvapi-******\n"
     ]
    }
   ],
   "source": [
    "%env NVIDIA_API_KEY=nvapi-******"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Prepare Dataset\n",
    "\n",
    "During the training process of DeepSeek-R1-Zero, DeepSeek mentioned they used data from math, code, science, and logic domains. However, since they haven't disclosed the specific data sources, we will use open-source datasets as examples.\n",
    "\n",
    "In the following code, we will use the [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k) from HuggingFace. \n",
    "\n",
    "You can also create your own dataset, but it's best to align with the example dataset's format, ensuring each entry contains both a `question` and an `answer`."
   ]
  },
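  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you bring your own data, a quick schema check before spending API calls can catch malformed entries early. The following is a minimal sketch, not part of the original pipeline: `validate_records` and the sample records are hypothetical, and it assumes each entry is a dict with the `problem` and `answer` string fields that the code in this notebook reads.\n",
    "\n",
    "```python\n",
    "REQUIRED_FIELDS = ('problem', 'answer')  # field names read by the later cells\n",
    "\n",
    "\n",
    "def validate_records(records):\n",
    "    '''Return the indices of records with a missing or empty required field.'''\n",
    "    bad_indices = []\n",
    "    for index, record in enumerate(records):\n",
    "        for field in REQUIRED_FIELDS:\n",
    "            value = record.get(field)\n",
    "            if not isinstance(value, str) or not value.strip():\n",
    "                bad_indices.append(index)\n",
    "                break  # one missing field is enough to flag the record\n",
    "    return bad_indices\n",
    "\n",
    "\n",
    "# Hypothetical sample records, for illustration only\n",
    "records = [\n",
    "    {'problem': 'What is 2 + 3?', 'answer': '5'},\n",
    "    {'problem': 'Solve $x^2 = 9$ for $x > 0$.', 'answer': ''},\n",
    "]\n",
    "print(validate_records(records))\n",
    "```\n",
    "\n",
    "Records that pass the check can then be turned into a dataset with `datasets.Dataset.from_list(records)`."
   ]
  },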
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating default split: 100%|██████████| 93733/93733 [00:05<00:00, 17333.48 examples/s]\n",
      "Generating extended split: 100%|██████████| 131396/131396 [00:04<00:00, 28622.32 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset size: 93733\n",
      "===== Problem 1 =====\n",
      "## Task B-1.3.\n",
      "\n",
      "A ship traveling along a river has covered $24 \\mathrm{~km}$ upstream and $28 \\mathrm{~km}$ downstream. For this journey, it took half an hour less than for traveling $30 \\mathrm{~km}$ upstream and $21 \\mathrm{~km}$ downstream, or half an hour more than for traveling $15 \\mathrm{~km}$ upstream and $42 \\mathrm{~km}$ downstream, assuming that both the ship and the river move uniformly.\n",
      "\n",
      "Determine the speed of the ship in still water and the speed of the river.\n",
      "===== Answer 1 =====\n",
      "v_{R}=4\\mathrm{~}/\\mathrm{},v_{B}=10\\mathrm{~}/\\mathrm{}\n",
      "\n",
      "\n",
      "===== Problem 2 =====\n",
      "3. (6 points) A construction company was building a tunnel. When $\\frac{1}{3}$ of the tunnel was completed at the original speed, they started using new equipment, which increased the construction speed by $20 \\%$ and reduced the working hours to $80 \\%$ of the original. As a result, it took a total of 185 days to complete the tunnel. If they had not used the new equipment and continued at the original speed, it would have taken $\\qquad$ days to complete the tunnel.\n",
      "===== Answer 2 =====\n",
      "180\n",
      "\n",
      "\n",
      "===== Problem 3 =====\n",
      "Prove that number $1$ can be represented as a sum of a finite number $n$ of real numbers, less than $1,$ not necessarily  distinct, which contain in their decimal representation only the digits $0$ and/or $7.$ Which is the least possible number $n$?\n",
      "===== Answer 3 =====\n",
      " 8 \n",
      "\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"open-r1/OpenR1-Math-220k\", split=\"train\")\n",
    "\n",
    "print(f\"Dataset size: {len(dataset)}\")\n",
    "\n",
    "# Print the first few examples\n",
    "for i, example in enumerate(dataset.select(range(3))):\n",
    "    print(f\"===== Problem {i+1} =====\")\n",
    "    print(example[\"problem\"])\n",
    "    print(f\"===== Answer {i+1} =====\")\n",
    "    print(example[\"answer\"])\n",
    "    print(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Distill Reasoning Data from DeepSeek-R1 Using NVIDIA NIM API\n",
    "\n",
    "DeepSeek recommends adhering to the following configurations when running inference the DeepSeek-R1 series of models, including benchmarking, to achieve the expected performance:\n",
    "\n",
    "- Set the temperature within the range of 0.5-0.7 (0.6 is recommended) to prevent endless repetitions or incoherent outputs.\n",
    "- Avoid adding a system prompt; all instructions should be contained within the user prompt.\n",
    "- For mathematical problems, it is advisable to include a directive in your prompt such as: \"Please reason step by step, and put your final answer within \\boxed{}.\""
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "This cell configures the NIM client and runs a basic distillation test. "
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from openai import OpenAI\n",
    "\n",
    "client = OpenAI(\n",
    "  base_url = \"https://integrate.api.nvidia.com/v1\",\n",
    "  api_key = os.getenv(\"NVIDIA_API_KEY\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<think>\n",
      "Okay, so I need to figure out which number is larger between 9.11 and 9.8. Let me start by writing them down to compare more easily. \n",
      "\n",
      "First, both numbers have the same whole number part, which is 9. That means the difference must be in the decimal parts. The first number is 9.11, and the second is 9.8. Hmm, let's break them down. \n",
      "\n",
      "For 9.11, the decimal part is 0.11, and for 9.8, the decimal part is 0.8. But wait, 0.8 is the same as 0.80, right? So if I compare 0.80 and 0.11, clearly 0.80 is larger. Therefore, 9.8 should be larger than 9.11. \n",
      "\n",
      "But let me double-check to make sure I'm not making a mistake here. Sometimes when decimals have different numbers of digits, it can be confusing. Let's line them up by their decimal places:\n",
      "\n",
      "9.11\n",
      "9.80\n",
      "\n",
      "Now, comparing each digit after the decimal point one by one. The first digit after the decimal is 1 in 9.11 and 8 in 9.80. Since 8 is greater than 1, that means 9.80 is larger right there. The rest of the digits don't matter because the first decimal place already determines which is bigger. \n",
      "\n",
      "Alternatively, I can convert them both to fractions to compare. Let's see. \n",
      "\n",
      "9.11 is the same as 9 + 11/100, which is 9 + 11/100 = 911/100. \n",
      "\n",
      "9.8 is the same as 9 + 8/10, which is 9 + 80/100 = 980/100. \n",
      "\n",
      "Now, comparing 911/100 and 980/100, since the denominators are the same, we just look at the numerators. 980 is greater than 911, so 980/100 (which is 9.8) is larger. \n",
      "\n",
      "Another way to think about it is to subtract one from the other. Let's subtract 9.11 from 9.8. \n",
      "\n",
      "9.80\n",
      "-9.11\n",
      "------\n",
      "0.69\n",
      "\n",
      "The result is 0.69, which is a positive number. That means 9.8 is 0.69 more than 9.11, so 9.8 is indeed larger. \n",
      "\n",
      "Wait, but just to be thorough, maybe I should visualize them on a number line. If I imagine the numbers between 9 and 10, 9.11 is just a little past 9.1, whereas 9.8 is much closer to 10. So clearly, 9.8 is further to the right on the number line, making it the larger number. \n",
      "\n",
      "Is there any chance that I might have confused 9.8 with 9.08? If 9.8 were written as 9.08, then 9.11 would be larger. But in this case, the original numbers are 9.11 and 9.8, so 9.8 is definitely 9.80, not 9.08. \n",
      "\n",
      "Also, considering place value: the tenths place in 9.11 is 1, and in 9.8 it's 8. Since the tenths place is more significant than the hundredths or thousandths places, the number with the higher tenths digit is larger, regardless of the subsequent digits. \n",
      "\n",
      "So all methods point to 9.8 being larger than 9.11. I think that's solid. No mistakes in my reasoning.\n",
      "\n",
      "**Final Answer**\n",
      "The larger number is \\boxed{9.8}.\n",
      "</think>\n",
      "\n",
      "To determine which number is larger between 9.11 and 9.8, we start by comparing their whole number parts, which are both 9. Therefore, we compare the decimal parts.\n",
      "\n",
      "1. Convert 9.8 to 9.80 to have the same number of decimal places as 9.11:\n",
      "   \\[\n",
      "   9.11 \\quad \\text{and} \\quad 9.80\n",
      "   \\]\n",
      "2. Compare the tenths place: 1 (from 9.11) vs. 8 (from 9.80). Since 8 > 1, 9.80 is larger.\n",
      "3. Converting both numbers to fractions confirms this:\n",
      "   \\[\n",
      "   9.11 = \\frac{911}{100} \\quad \\text{and} \\quad 9.8 = \\frac{980}{100}\n",
      "   \\]\n",
      "   Since 980 > 911, \\(\\frac{980}{100} > \\frac{911}{100}\\).\n",
      "4. Subtracting 9.11 from 9.8 gives a positive result:\n",
      "   \\[\n",
      "   9.80 - 9.11 = 0.69\n",
      "   \\]\n",
      "5. Visualizing on a number line, 9.8 is closer to 10, confirming it is larger.\n",
      "\n",
      "Thus, the larger number is \\boxed{9.8}."
     ]
    }
   ],
   "source": [
    "# A simple test case\n",
    "problem = \"which number is larger, 9.11 or 9.8?\"\n",
    "completion = client.chat.completions.create(\n",
    "    model=\"deepseek-ai/deepseek-r1\",\n",
    "    messages=[{\"role\": \"user\", \"content\": f\"Please reason step by step, and put your final answer within \\\\boxed{{}}. {problem}\"}],\n",
    "    temperature=0.6,\n",
    "    top_p=0.7,\n",
    "    max_tokens=32768,\n",
    "    timeout=1000,\n",
    "    stream=True\n",
    ")\n",
    "\n",
    "for chunk in completion:\n",
    "    if chunk.choices[0].delta.content is not None:\n",
    "        print(chunk.choices[0].delta.content, end=\"\")"
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Now, we're ready to generate reasoning traces using DeepSeek-R1 for the entire dataset."
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The prompt template recommended by DeepSeek for math problems\n",
    "PROMPT_TEMPLATE = \"Please reason step by step, and put your final answer within \\\\boxed{{}}. {problem}\"\n",
    "\n",
    "def process_streaming_response(completion):\n",
    "    \"\"\"Process the streaming response from the R1 model\"\"\"\n",
    "    reasoning_trace = \"\"\n",
    "    try:\n",
    "        for chunk in completion:\n",
    "            if chunk.choices[0].delta.content is not None:\n",
    "                reasoning_trace += chunk.choices[0].delta.content\n",
    "        return reasoning_trace\n",
    "    except Exception as e:\n",
    "        print(f\"Error occurred: {e}\")\n",
    "        return reasoning_trace\n",
    "\n",
    "def distill_data_from_r1(example):\n",
    "    problem = example[\"problem\"]\n",
    "    completion = client.chat.completions.create(\n",
    "        model=\"nvdev/deepseek-ai/deepseek-r1\",\n",
    "        messages=[{\"role\": \"user\", \"content\": PROMPT_TEMPLATE.format(problem=problem)}],\n",
    "        temperature=0.6,\n",
    "        top_p=0.7,\n",
    "        max_tokens=32768,\n",
    "        timeout=10000,\n",
    "        stream=True\n",
    "    )\n",
    "    \n",
    "    reasoning_trace = process_streaming_response(completion)\n",
    "    return {**example, \"reasoning_trace\": reasoning_trace}"
   ]
  },
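  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Long streaming responses can be cut off when the connection is closed midway, which leaves a truncated trace (the sample run in this notebook shows such errors). One way to reduce these losses is to retry a request whose trace is incomplete. The following is a minimal sketch under stated assumptions: `generate_until_complete` is a hypothetical helper, not part of the OpenAI client or NIM; `generate` is any zero-argument callable that returns a trace string; and a trace is treated as complete when it contains `</think>`.\n",
    "\n",
    "```python\n",
    "import time\n",
    "\n",
    "\n",
    "def generate_until_complete(generate, max_attempts=3, base_delay=1.0, sleep=time.sleep):\n",
    "    '''Call generate() until the trace looks complete or attempts run out.\n",
    "\n",
    "    Waits base_delay * 2**attempt seconds between attempts (exponential backoff).\n",
    "    Returns the last trace, which may still be incomplete.\n",
    "    '''\n",
    "    trace = ''\n",
    "    for attempt in range(max_attempts):\n",
    "        trace = generate()\n",
    "        if '</think>' in trace:\n",
    "            return trace\n",
    "        # Back off before the next attempt, but not after the final one\n",
    "        if attempt < max_attempts - 1:\n",
    "            sleep(base_delay * 2 ** attempt)\n",
    "    return trace\n",
    "```\n",
    "\n",
    "With the functions defined above, a call could look like `generate_until_complete(lambda: distill_data_from_r1(example)['reasoning_trace'])`. Each retry issues a full new request, so `max_attempts` trades API cost against yield."
   ]
  },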
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Distilling reasoning traces from R1:  50%|█████     | 1/2 [05:10<05:10, 310.33s/ examples]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error occurred: peer closed connection without sending complete message body (incomplete chunked read)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Distilling reasoning traces from R1: 100%|██████████| 2/2 [10:20<00:00, 310.33s/ examples]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error occurred: peer closed connection without sending complete message body (incomplete chunked read)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# To speed up the process, we only use 2 examples here\n",
    "sample_dataset = dataset.select(range(2))\n",
    "\n",
    "# You can set num_proc to speed up the process\n",
    "sample_dataset = sample_dataset.map(distill_data_from_r1, num_proc=1, desc=\"Distilling reasoning traces from R1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['<think>\\nOkay, let\\'s see. So, the problem is about a ship traveling upstream and downstream on a river. We need to find the speed of the ship in still water and the speed of the river. Hmm, classic upstream/downstream problem. These usually involve the concept that when going upstream, the effective speed is the ship\\'s speed minus the river\\'s speed, and downstream it\\'s the ship\\'s speed plus the river\\'s speed. Let me write that down.\\n\\nLet’s denote:\\n- \\\\( v \\\\) as the speed of the ship in still water (km/h)\\n- \\\\( u \\\\) as the speed of the river current (km/h)\\n\\nSo, upstream speed would be \\\\( v - u \\\\), and downstream speed would be \\\\( v + u \\\\).\\n\\nThe problem mentions three different journeys:\\n\\n1. First journey: 24 km upstream and 28 km downstream. The time taken for this is half an hour less than the second journey.\\n2. Second journey: 30 km upstream and 21 km downstream. The time for this is half an hour more than the first journey. Wait, actually, the problem states: \"For this journey, it took half an hour less than for traveling 30 km upstream and 21 km downstream, or half an hour more than for traveling 15 km upstream and 42 km downstream.\" Hmm, that wording is a bit confusing. Let me parse it again.\\n\\nOriginal problem statement: \"A ship traveling along a river has covered 24 km upstream and 28 km downstream. For this journey, it took half an hour less than for traveling 30 km upstream and 21 km downstream, or half an hour more than for traveling 15 km upstream and 42 km downstream, assuming that both the ship and the river move uniformly.\"\\n\\nWait, so the first journey (24 up, 28 down) took half an hour less than the second journey (30 up, 21 down). Alternatively, the same first journey took half an hour more than the third journey (15 up, 42 down). So, we have two different comparisons here. Let me structure this.\\n\\nLet’s define:\\n\\n- Journey 1: 24 km upstream and 28 km downstream. 
Time = T1\\n- Journey 2: 30 km upstream and 21 km downstream. Time = T2\\n- Journey 3: 15 km upstream and 42 km downstream. Time = T3\\n\\nGiven that T1 = T2 - 0.5 hours (since it took half an hour less than Journey 2)\\nAnd also T1 = T3 + 0.5 hours (since it took half an hour more than Journey 3)\\n\\nSo, we have two equations here:\\n\\n1. T1 = T2 - 0.5\\n2. T1 = T3 + 0.5\\n\\nTherefore, T2 - 0.5 = T3 + 0.5 => T2 = T3 + 1\\n\\nBut maybe we don\\'t need that relation. Let\\'s see. Let\\'s express each time in terms of the speeds.\\n\\nTime = Distance / Speed\\n\\nSo, for Journey 1:\\nT1 = (24 / (v - u)) + (28 / (v + u))\\n\\nFor Journey 2:\\nT2 = (30 / (v - u)) + (21 / (v + u))\\n\\nFor Journey 3:\\nT3 = (15 / (v - u)) + (42 / (v + u))\\n\\nGiven that T1 = T2 - 0.5 and T1 = T3 + 0.5. So, substituting the expressions:\\n\\nFirst equation:\\n(24 / (v - u)) + (28 / (v + u)) = (30 / (v - u)) + (21 / (v + u)) - 0.5\\n\\nSecond equation:\\n(24 / (v - u)) + (28 / (v + u)) = (15 / (v - u)) + (42 / (v + u)) + 0.5\\n\\nSo, now we have two equations with two variables, \\\\( v \\\\) and \\\\( u \\\\). Let\\'s simplify both equations.\\n\\nStarting with the first equation:\\n\\nLeft side: 24/(v - u) + 28/(v + u)\\n\\nRight side: 30/(v - u) + 21/(v + u) - 0.5\\n\\nSubtracting the right side from both sides to bring everything to the left:\\n\\n24/(v - u) + 28/(v + u) - 30/(v - u) - 21/(v + u) + 0.5 = 0\\n\\nSimplify terms:\\n\\n(24 - 30)/(v - u) + (28 - 21)/(v + u) + 0.5 = 0\\n\\nWhich is:\\n\\n(-6)/(v - u) + 7/(v + u) + 0.5 = 0\\n\\nSimilarly, for the second equation:\\n\\nLeft side: 24/(v - u) + 28/(v + u)\\n\\nRight side: 15/(v - u) + 42/(v + u) + 0.5\\n\\nBringing everything to the left:\\n\\n24/(v - u) + 28/(v + u) - 15/(v - u) - 42/(v + u) - 0.5 = 0\\n\\nSimplify terms:\\n\\n(24 - 15)/(v - u) + (28 - 42)/(v + u) - 0.5 = 0\\n\\nWhich is:\\n\\n9/(v - u) - 14/(v + u) - 0.5 = 0\\n\\nSo now, we have two equations:\\n\\n1. (-6)/(v - u) + 7/(v + u) + 0.5 = 0\\n2. 
9/(v - u) - 14/(v + u) - 0.5 = 0\\n\\nLet me denote \\\\( x = 1/(',\n",
       " '<think>\\nOkay, let\\'s see. The problem is about a construction company building a tunnel. They first built 1/3 of the tunnel at the original speed, then started using new equipment which increased their speed by 20% and reduced working hours to 80% of the original. The total time taken was 185 days. We need to find how many days it would have taken if they continued at the original speed without the new equipment.\\n\\nHmm. Let me break this down step by step.\\n\\nFirst, let\\'s define some variables. Let\\'s say the total length of the tunnel is 1 unit (since we\\'re dealing with fractions, this might make calculations easier). Let the original construction speed be \\'v\\' units per day. The original working hours per day would be \\'h\\' hours. But wait, the problem mentions that with the new equipment, the working hours were reduced to 80% of the original. So maybe we need to consider the relationship between speed, working hours, and the actual construction rate.\\n\\nWait, speed here is probably the rate of construction, which could be affected by both the efficiency of the equipment and the number of hours worked each day. So if they increased their construction speed by 20%, that could be due to more efficient equipment, but they also reduced the working hours to 80% of the original. So the overall effect on the daily construction rate would be a combination of these two factors.\\n\\nLet me clarify. 
Let\\'s denote:\\n\\n- Original construction speed: v (units per hour)\\n- Original working hours per day: h\\n- Therefore, original daily construction rate: v * h (units per day)\\n\\nWith the new equipment:\\n\\n- Construction speed increased by 20%, so new speed = 1.2v (units per hour)\\n- Working hours reduced to 80% of original, so new working hours = 0.8h (hours per day)\\n- Therefore, new daily construction rate = 1.2v * 0.8h = 0.96vh (units per day)\\n\\nWait, so even though the speed increased by 20%, the working hours decreased by 20%, resulting in a net daily rate of 96% of the original? That seems like a 4% decrease in daily rate. Hmm, is that correct?\\n\\nWait, 20% increase in speed is multiplicative, and 20% decrease in hours is also multiplicative. So 1.2 * 0.8 = 0.96. Yes, that\\'s correct. So the daily construction rate actually decreased by 4%? That seems counterintuitive. If they work faster but fewer hours, the net effect is a slight decrease in daily progress. Interesting.\\n\\nBut maybe the problem is phrased differently. Let me check again.\\n\\nThe problem says: \"increased the construction speed by 20% and reduced the working hours to 80% of the original\". So speed per hour is increased by 20%, and working hours per day is reduced to 80%. So yes, the daily rate would be 1.2 * 0.8 = 0.96 times the original. So their daily progress is actually 96% of what it was before. Wait, that\\'s a 4% decrease. So even though they are working faster, because they are working fewer hours each day, their overall daily progress is slightly less. 
So that\\'s a key point.\\n\\nTherefore, when they started using the new equipment, their daily rate became 0.96vh, which is 0.96 times the original daily rate.\\n\\nBut maybe I need to think in terms of time taken for the remaining portion.\\n\\nLet me approach this step by step.\\n\\nLet’s denote:\\n\\n- Total tunnel length: 1 unit (for simplicity)\\n- Original construction speed: v units per day\\n- Original total time to complete the tunnel: T days. Therefore, v = 1 / T. Because if they build at speed v, then total time is 1 / v. Wait, actually, speed is units per day, so time is total units divided by speed. So if original speed is v, then original time T = 1 / v. Alternatively, if they take T days at speed v, then v = 1 / T.\\n\\nBut maybe another way. Let me think.\\n\\nSuppose without the new equipment, the total time would be T days. So the original speed is 1/T units per day.\\n\\nBut when they started using the new equipment after completing 1/3 of the tunnel, their speed changed. So the first 1/3 was done at the original speed, and the remaining 2/3 was done at the new speed, which is 20% higher, but working hours reduced to 80%, leading to a daily rate of 0.96 times the original.\\n\\nWait, but is the speed per hour increased by 20%, and hours per day reduced to 80%, so total daily output is 1.2 * 0.8 = 0.96 times original. So if original daily output was 1/T per day (since total is 1 unit), then after the change, daily output is 0.96/T per day.\\n\\nBut maybe that\\'s not the right way to model it. Let me try again.\\n\\nLet’s suppose that originally, the construction speed is such that they can complete the tunnel in T days. So their original rate is 1/T per day.\\n\\nBut when they switch to the new equipment, their rate changes. The rate is affected by two factors: a 20% increase in speed and a 20% decrease in working hours. So the new rate would be original speed * 1.2 (due to the speed increase) multiplied by 0.8 (due to fewer hours). 
So 1.2 * 0.8 = 0.96. Therefore, the new rate']"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample_dataset['reasoning_trace']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Post-Process Distilled Data\n",
    "\n",
    "After generating data, we should filter out any low-quality reasoning data. We can establish some filtering rules, such as:\n",
    "- Whether the language in the reasoning trace meets requirements\n",
    "- Whether the reasoning trace format is correct, i.e., wrapping the thinking process in `<think></think>` tags before giving the final answer\n",
    "- Whether the answer given in the reasoning trace is correct\n",
    "- Other filtering rules mentioned in the R1 paper\n",
    "    - Long paragraphs\n",
    "    - Containing Code blocks\n",
    "\n",
    "In this tutorial, we will only verify the format and the correctness of the answers."
   ]
  },
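  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The long-paragraph and code-block rules listed above are not implemented in this tutorial, but they can be approximated with simple string checks. The following is a rough sketch, not the filtering used in the R1 paper: `has_code_block` and `has_long_paragraph` are hypothetical helpers, paragraphs are assumed to be separated by blank lines, and the 2000-character threshold is an arbitrary illustrative value.\n",
    "\n",
    "```python\n",
    "FENCE = '`' * 3  # Markdown code-fence marker (three backticks)\n",
    "\n",
    "\n",
    "def has_code_block(trace):\n",
    "    '''Return True if the trace contains a Markdown code fence.'''\n",
    "    return FENCE in trace\n",
    "\n",
    "\n",
    "def has_long_paragraph(trace, max_chars=2000):\n",
    "    '''Return True if any blank-line-separated paragraph exceeds max_chars.'''\n",
    "    paragraphs = trace.split('\\n\\n')\n",
    "    return any(len(paragraph) > max_chars for paragraph in paragraphs)\n",
    "```\n",
    "\n",
    "Checks like these could be applied next to the format and answer checks, each with its own `filtered_reason` value so that the share of traces removed by every rule stays visible."
   ]
  },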
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from latex2sympy2_extended import NormalizationConfig\n",
    "from math_verify import LatexExtractionConfig, parse, verify\n",
    "\n",
    "\n",
    "def check_format(reasoning_trace):\n",
    "    pattern = r\"^<think>.*?</think>\"\n",
    "    if not re.match(pattern, reasoning_trace, re.DOTALL | re.MULTILINE):\n",
    "        return False\n",
    "    # check if all tags only appear once\n",
    "    tags = [\"<think>\", \"</think>\"]\n",
    "    for tag in tags:\n",
    "        if reasoning_trace.count(tag) != 1:\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "# We use math_verify to check if the answer is mathematically equivalent to the ground truth\n",
    "def calculate_answer(reasoning_trace, ground_truth):\n",
    "    \"\"\"Check if the answer is the same as the ground truth.\"\"\"\n",
    "    answer_parsed = parse(\n",
    "        reasoning_trace,\n",
    "        extraction_config=[\n",
    "            LatexExtractionConfig(\n",
    "                normalization_config=NormalizationConfig(\n",
    "                    nits=False,\n",
    "                    malformed_operators=False,\n",
    "                    basic_latex=True,\n",
    "                    equations=True,\n",
    "                    boxed=True,\n",
    "                    units=True,\n",
    "                ),\n",
    "                # Ensures that boxed is tried first\n",
    "                boxed_match_priority=0,\n",
    "                try_extract_without_anchor=False,\n",
    "            )\n",
    "        ],\n",
    "        extraction_mode=\"first_match\",\n",
    "    )\n",
    "\n",
    "    return verify(answer_parsed, ground_truth)\n",
    "\n",
    "def filter_reasoning_trace(example):\n",
    "    reasoning_trace = example[\"reasoning_trace\"]\n",
    "    ground_truth = example[\"answer\"]\n",
    "    if not check_format(reasoning_trace):\n",
    "        return {**example, \"filtered\": True, \"filtered_reason\": \"INVALID_FORMAT\"}\n",
    "    if not calculate_answer(reasoning_trace, ground_truth):\n",
    "        return {**example, \"filtered\": True, \"filtered_reason\": \"INCORRECT_ANSWER\"}\n",
    "    return {**example, \"filtered\": False, \"filtered_reason\": \"VALID\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Filtering reasoning traces: 100%|██████████| 2/2 [00:00<00:00, 159.56 examples/s]\n",
      "Filter: 100%|██████████| 2/2 [00:00<00:00, 957.49 examples/s]\n",
      "Saving the dataset (1/1 shards): : 0 examples [00:00, ? examples/s]\n"
     ]
    }
   ],
   "source": [
    "sample_dataset = sample_dataset.map(filter_reasoning_trace, desc=\"Filtering reasoning traces\")\n",
    "\n",
    "# filter out the invalid reasoning traces\n",
    "filtered_dataset = sample_dataset.filter(lambda x: not x[\"filtered\"])\n",
    "\n",
    "# save the filtered dataset\n",
    "filtered_dataset.save_to_disk(\"filtered_dataset\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next Steps\n",
    "\n",
    "Due to the randomness of the reasoning process, we can run the above process multiple times to generate multiple reasoning traces for each question. Then, we can apply quality filtering to construct the distilled dataset.\n",
    "\n",
    "After collecting the distilled dataset, you can refer to [the qwen2 distillation notebook](./qwen2_distill_nemo.ipynb) to train your model using this dataset."
   ]
  },
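  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to combine several passes is to keep a small number of distinct valid traces per question. The following is a minimal sketch under stated assumptions: `merge_runs` is a hypothetical helper, each run is assumed to be an iterable of dicts with the `problem`, `reasoning_trace`, and `filtered` fields produced by `filter_reasoning_trace`, and the cap of two traces per problem is arbitrary.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def merge_runs(runs, max_per_problem=2):\n",
    "    '''Keep up to max_per_problem distinct unfiltered traces for each problem.'''\n",
    "    kept = defaultdict(list)\n",
    "    for run in runs:\n",
    "        for record in run:\n",
    "            if record['filtered']:\n",
    "                continue  # skip traces rejected by the quality filters\n",
    "            traces = kept[record['problem']]\n",
    "            is_new = record['reasoning_trace'] not in traces\n",
    "            if is_new and len(traces) < max_per_problem:\n",
    "                traces.append(record['reasoning_trace'])\n",
    "    # Flatten back into one record per (problem, trace) pair\n",
    "    return [\n",
    "        {'problem': problem, 'reasoning_trace': trace}\n",
    "        for problem, traces in kept.items()\n",
    "        for trace in traces\n",
    "    ]\n",
    "```\n",
    "\n",
    "Each run can be a `Dataset` returned by the filtering step, since iterating over a `Dataset` yields one dict per example."
   ]
  },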
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
